import sqlite3

import lib.util as tool
import jieba
import json
import lib.sql as Sql
from lib import util
from lib.dbop import Database
import config as config
import torch


class DataCenter:
    """Data-access layer for the chatbot corpus.

    Responsibilities: load conversation corpora into SQLite, maintain a
    character dictionary table, and convert text to padded index lists /
    LongTensors for training.
    """

    def __init__(self, dbpath=None):
        """Create a DataCenter, optionally bound to a SQLite database.

        dbpath: path to the database file; when None the instance is
        usable for the pure helper methods without opening a database.
        """
        self.db = Database()
        # BUGFIX: dicts must exist even when no dbpath is given, otherwise
        # transfer_word_2_number() raises AttributeError on self.dicts.
        self.dicts: dict = {}
        if dbpath is None:
            return
        self.db.init(dbpath)

    def set_dict(self, dicts):
        """Inject a pre-built {char: index} dictionary (see load_dict)."""
        self.dicts = dicts

    def load_conv(self, filepath):
        """Parse a conversation corpus ('E' separator lines, 'M'-prefixed
        utterance lines) and insert each utterance into the database."""
        # 'with' guarantees the handle is closed even if parsing fails.
        with open(filepath, mode='r', encoding='utf-8') as file:
            file_list = file.read().split('\n')
        for i, single in enumerate(file_list):
            print("第{}条".format(i + 1))
            if single == 'E':
                continue
            # Drop the 'M' role markers, keep only the utterance text.
            real = ''.join(part for part in single.split(' ') if part != 'M')
            # Escape single quotes for the SQL string literal below.
            real = real.replace("'", "\\'")
            # BUGFIX: ''.isspace() is False, so empty utterances used to
            # slip through and get inserted; skip empty and blank strings.
            if not real.strip():
                continue
            self.add_to_db(real)

    def add_to_db(self, real):
        """Insert one utterance; duplicate keys are reported, not fatal.

        NOTE(review): Sql.insert_conv is built via str.format — this is
        vulnerable to SQL injection if the corpus is untrusted. Prefer a
        parameterized query if Database.insert supports placeholders.
        """
        real_sql = Sql.insert_conv.format(real)
        try:
            self.db.insert(real_sql, True)
        except sqlite3.IntegrityError:
            # Primary key was switched to an Id so text rows can repeat;
            # a conflict here just means the row already exists.
            print("主键冲突 数据已经存在")
        except Exception as e:
            print("error->", e)

    def get_all_data(self=None):
        """Placeholder removed; see get_data_set."""
        raise NotImplementedError

    def get_data_set(self) -> dict:
        """Load every configured conversation dataset, keyed by corpus name."""
        # Removed unused locals (config.db_path / tool.get_all_file result
        # were computed and never read).
        return {
            "fanzxl": self.get_single_data(config.fanzxl_db_path),
            "xiaohuangji": self.get_single_data(config.xiaohuangji_db_path),
        }

    def get_single_data(self, db_path) -> list:
        """Open the given database file and return all conversation rows."""
        self.db.init(db_path)
        return self.db.select(sql=Sql.select_conv)

    def gen_dict(self, file_path):
        """Read a character-list JSON file and insert characters that are
        not yet present in the dictionary table."""
        # 'with' fixes a leak: the old code never closed this file.
        with open(file_path, mode='r', encoding='utf-8') as file:
            json_data = json.loads(file.read())
        dicts = self.load_dict()
        for i, entry in enumerate(json_data):
            char = entry["char"]
            # BUGFIX: the old `dicts[char]` raised KeyError for every NEW
            # character and the handler skipped the insert, so nothing new
            # was ever written. A membership test handles both cases.
            if char in dicts:
                print("{}已存在".format(char))
                continue
            print("正在插入{}条数据 word={}".format(i + 1, char))
            sql = Sql.insert_dict.format(char)
            try:
                self.db.insert(sql, True)
            except Exception as e:
                print("插入报错 error->", e)

    def load_dict(self) -> dict:
        """Load the dictionary table and return the bidirectional mapping
        produced by transfer_dict."""
        self.db.init(config.dictionary_path)
        rows = self.db.select(Sql.select_dict)
        return self.transfer_dict(rows)

    def transfer_dict(self, dicts: list) -> dict:
        """Convert dictionary rows into one dict holding BOTH directions:
        {index: word} and {word: index}, for O(1) lookup either way.

        Indices start at 1 so that 0 stays free for padding.
        Side effect: updates config.INPUT_DIM / config.OUTPUT_DIM.
        """
        res = {}
        for i, row in enumerate(dicts):
            word = row["word"]
            res[i + 1] = word
            res[word] = i + 1
        # +5 leaves headroom for special tokens — presumably pad/EOF etc.;
        # TODO(review): confirm the intended reserve size.
        config.INPUT_DIM = len(dicts) + 5
        config.OUTPUT_DIM = len(dicts) + 5
        return res

    def transfer_word_2_number(self, word: str):
        """Map each character of `word` to its dictionary index (-1 for
        unknown characters) and zero-pad to config.DATA_LENGTH."""
        # Unknown chars become -1 on purpose: inserting them on the fly
        # would change the vocabulary size mid-training.
        res = [self.dicts.get(ch, -1) for ch in word]
        return self.pad_zero_with_list(res)

    def pad_zero_with_list(self, data):
        """Right-pad `data` in place with zeros up to config.DATA_LENGTH.

        Raises ValueError when data already exceeds the limit.
        """
        length = config.DATA_LENGTH
        if len(data) > length:
            # BUGFIX: the exception object was created but never raised,
            # so over-long data silently passed through unpadded.
            raise ValueError('数据超过规定的长度{}'.format(length))
        data.extend([0] * (length - len(data)))
        return data

    def add_words_which_not_in_dictionary(self, word):
        """Insert a missing character into the dictionary table and reload
        self.dicts so the new index is immediately usable."""
        self.db.init(config.dictionary_path)
        self.db.insert(Sql.insert_dict.format(word), True)
        self.dicts = self.load_dict()

    def get_all_token(self, datas):
        """Build a {token: index} vocabulary over every character found in
        datas[*]['sentence'], reserving index 0 for 'EOF'."""
        token2idx = {"EOF": 0}
        idx = 1
        for sentence in datas:
            for word in sentence['sentence']:
                if word not in token2idx:
                    token2idx[word] = idx
                    idx += 1
        return token2idx

    def sentence2vec(self, sentence, token2idx):
        """Map each token of `sentence` to its index; return a 1-D LongTensor.

        Raises KeyError for tokens absent from token2idx (as before).
        """
        # One allocation instead of repeated torch.cat: O(n) vs O(n^2),
        # identical result including the empty-sentence case.
        return torch.LongTensor([token2idx[token] for token in sentence])

    def load_token(self):
        """Load the persisted {token: index} mapping from config.token_path."""
        with open(config.token_path, encoding='utf-8', mode='r') as fs:
            return json.loads(fs.read())

    def generate_data(self, src):
        """Stack `src` with all-zero rows into a batch LongTensor.

        NOTE(review): the result has config.BATCH_SIZE + 1 rows (src plus
        BATCH_SIZE padding rows) — confirm this off-by-one is intended.
        """
        res = [src.numpy().tolist()]
        for _ in range(config.BATCH_SIZE):
            res.append(self.generate_empty_data())
        return torch.LongTensor(res)

    def generate_empty_data(self):
        """Return an all-zero list of length config.LONG_LIMIT."""
        return [0] * config.LONG_LIMIT
